Skip to content

Conversation

@hfrick
Copy link
Member

@hfrick hfrick commented Jan 25, 2023

closes #858

This PR moves where we set s, glmnet's argument for the penalty value, from inside the relevant multi_predict() method into the relevant predict_raw() method.

This means that now the default penalty value specified in the parsnip spec will be used when penalty = NULL, also for for type = "raw" (aka the original issue).

This also means that the call stack for multi_predict() for logistic and multinomial regression now follows that of linear regression and what happens in the code is what was laid out in the comments, e.g.,

parsnip/R/logistic_reg.R

Lines 209 to 234 in 2249cbb

# ------------------------------------------------------------------------------
# glmnet call stack for logistic regression using `predict` when object has
# classes "_lognet" and "model_fit" (for class predictions):
#
# predict()
# predict._lognet(penalty = NULL) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_class()
# predict_class._lognet()
# predict_class.model_fit()
# predict.lognet()
# glmnet call stack for logistic regression using `multi_predict` when object has
# classes "_lognet" and "model_fit" (for class predictions):
#
# multi_predict()
# multi_predict._lognet(penalty = NULL)
# predict._lognet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._lognet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.lognet()
# ------------------------------------------------------------------------------

The tests are in extratests: tidymodels/extratests#72

library(parsnip)

data(lending_club, package = "modeldata")

lr_spec <- logistic_reg(penalty = 0.123) %>% set_engine("glmnet")
f_fit <- fit(lr_spec, Class ~ log(funded_amnt) + int_rate + term,
             data = lending_club)
predict(f_fit, lending_club[1:5, ], type = "raw")
#>         s1
#> 1 2.894019
#> 2 2.894019
#> 3 2.894019
#> 4 2.894019
#> 5 2.894019

data("penguins", package = "modeldata")
penguins <- tidyr::drop_na(penguins)

mr_spec <- multinom_reg(penalty = 0.123) %>% set_engine("glmnet")
f_fit <- fit(mr_spec, species ~ island + bill_length_mm + bill_depth_mm,
             data = penguins)

predict(f_fit, penguins[1:5,], type = "raw")
#> , , 1
#> 
#>      Adelie Chinstrap    Gentoo
#> 1 -5.131312 -7.544642 -7.542310
#> 2 -5.229012 -7.544642 -6.800678
#> 3 -5.424412 -7.544642 -7.142970
#> 4 -4.545112 -7.544642 -7.884602
#> 5 -5.180162 -7.544642 -8.626234

Created on 2023-01-25 with reprex v2.0.2

so that it:
- also gets applied in `predict(type = "raw")`
- structure follows that of `linear_reg()`, which is also laid out in the comments
@topepo topepo merged commit a517c87 into main Feb 6, 2023
@topepo topepo deleted the glmnet-penalty-raw branch February 6, 2023 14:57
@github-actions
Copy link
Contributor

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Feb 21, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

glmnet engines should always respect the penalty value set in the spec

3 participants